package com.xnck.mfpms.dao;


import com.xnck.mfpms.entity.QueryPart;
import org.beetl.core.Configuration;
import org.beetl.core.GroupTemplate;
import org.beetl.core.Template;
import org.beetl.core.resource.StringTemplateResourceLoader;
import org.nutz.dao.Dao;
import org.nutz.dao.Sqls;
import org.nutz.dao.entity.Entity;
import org.nutz.dao.sql.Sql;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SqlTool {

    /**
     * Shared Beetl GroupTemplate, created lazily by {@link #gt()}.
     * Reads and writes happen only inside the synchronized {@code gt()} method.
     */
    private static GroupTemplate gt;

    /**
     * Intercepts a named SQL statement and rebuilds it according to the current user's data rules.
     *
     * <p>The source text of the named SQL is rendered as a Beetl template in which every query part
     * available to the user is bound to {@code true}. In the rendered text, tabs, carriage returns
     * and line feeds are replaced by spaces and the result is trimmed. That text is then written
     * back as the statement's source SQL.
     *
     * @param dao       Nutz Dao used to load SQL definitions and look up the available query parts
     * @param sqlName   name of the SQL statement to load; also used as the query name for the part lookup
     * @param curUserId id of the current user
     * @return the Sql created from {@code sqlName}, with its source replaced by the rendered text
     */
    public static Sql injectSearchSql(Dao dao, String sqlName, String curUserId) {
        Map<String, Object> activeParts = getActiveQueryParts(dao, sqlName, curUserId);
        Sql sql = dao.sqls().create(sqlName);
        String oldSql = sql.getSourceSql();
        Template template = gt().getTemplate(oldSql);
        template.binding(activeParts);
        // Collapse whitespace control characters so the rendered SQL ends up on a single line.
        String newSql = template.render()
                .replace('\t', ' ')
                .replace('\r', ' ')
                .replace('\n', ' ')
                .trim();
        sql.setSourceSql(newSql);
        return sql;
    }

    /**
     * Collects the names of the query parts the user may use for the given query.
     *
     * <p>Parts are loaded with the {@code injectDao.getQueryPartByRole} and
     * {@code injectDao.getQueryPartByUser} SQL definitions. Presumably these return role-granted and
     * directly granted parts respectively; the SQL itself is defined elsewhere, so confirm there.
     * The result is the union of both lists. Each part name maps to {@code Boolean.TRUE} so it can be
     * bound as a template flag.
     *
     * @param dao       Nutz Dao used to run the lookups
     * @param queryName name of the query whose parts are requested
     * @param curUserId id of the current user
     * @return mutable map from part name to {@code true}; empty when no parts are found
     */
    private static Map<String, Object> getActiveQueryParts(Dao dao, String queryName, String curUserId) {
        List<QueryPart> roleParts = getQueryPartsFromSql(dao, "injectDao.getQueryPartByRole", queryName, curUserId);
        List<QueryPart> userParts = getQueryPartsFromSql(dao, "injectDao.getQueryPartByUser", queryName, curUserId);
        Map<String, Object> map = new HashMap<String, Object>();
        // Every value is the same constant, so a plain put() already produces the union;
        // no duplicate check is needed when merging the second list.
        putParts(map, roleParts);
        putParts(map, userParts);
        return map;
    }

    /** Marks each part name in {@code parts} as active in {@code target}; a null list is ignored. */
    private static void putParts(Map<String, Object> target, List<QueryPart> parts) {
        if (parts == null) {
            return;
        }
        for (QueryPart part : parts) {
            target.put(part.getPartname(), Boolean.TRUE);
        }
    }

    /**
     * Runs a named SQL with {@code userId} and {@code queryName} parameters and maps the rows to QueryPart.
     *
     * @param dao       Nutz Dao used to create and execute the SQL
     * @param sqlName   name of the SQL definition to execute
     * @param queryName value bound to the {@code queryName} parameter
     * @param curUserId value bound to the {@code userId} parameter
     * @return the list produced by the entities callback
     */
    private static List<QueryPart> getQueryPartsFromSql(Dao dao, String sqlName, String queryName, String curUserId) {
        Sql sql = dao.sqls().create(sqlName);
        sql.params().set("userId", curUserId).set("queryName", queryName);
        sql.setCallback(Sqls.callback.entities());
        Entity<QueryPart> entity = dao.getEntity(QueryPart.class);
        sql.setEntity(entity);
        dao.execute(sql);
        return sql.getList(QueryPart.class);
    }

    /**
     * Returns the shared Beetl GroupTemplate, creating it on first use.
     *
     * <p>The template group loads templates from strings via {@link StringTemplateResourceLoader}. It uses
     * the default Beetl configuration, with the statement start delimiter set to {@code "#"} and the
     * statement end delimiter set to {@code null} (in Beetl, a null end delimiter means end of line).
     *
     * <p>Synchronized so that concurrent first calls cannot build two instances. It also ensures no
     * thread observes an unsafely published one.
     *
     * @return the shared GroupTemplate
     * @throws RuntimeException wrapping the IOException if the default configuration cannot be loaded
     */
    public static synchronized GroupTemplate gt() {
        if (gt == null) {
            StringTemplateResourceLoader resourceLoader = new StringTemplateResourceLoader();
            Configuration cfg;
            try {
                cfg = Configuration.defaultConfiguration();
                cfg.setStatementStart("#");
                cfg.setStatementEnd(null);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            gt = new GroupTemplate(resourceLoader, cfg);
        }
        return gt;
    }
}
